#!/usr/bin/python

### Largely self-contained image binarization and deskewing.  You can easily 
### use this as a basis for other kinds of preprocessing.

import sys,os
import signal
# Turn Ctrl-C (SIGINT) into sys.exit(1), i.e. a SystemExit with status 1,
# instead of Python's default KeyboardInterrupt.
signal.signal(signal.SIGINT,lambda *args:sys.exit(1))
import traceback
import argparse
import multiprocessing

import matplotlib
# The backend is selected before pylab is imported: AGG when no DISPLAY is
# set in the environment, GTK otherwise.
# NOTE(review): whether a "GTK" backend exists depends on the installed
# matplotlib version -- confirm for the target environment.
if "DISPLAY" not in os.environ: matplotlib.use("AGG")
else: matplotlib.use("GTK")
import pylab

# all the image processing code comes from scipy and pylab
from pylab import *
from scipy.ndimage import measurements,interpolation

# ocrolib is only used for image I/O and pathname manipulation
import ocrolib

# Command line interface.  The description below is what --help prints, so it
# has to match what process1() actually writes (NNNN.nrm.png / NNNN.bin.png).
parser = argparse.ArgumentParser(description = """
Perform document image preprocessing:

- Sauvola binarization
- deskewing
- large and small component removal

Images are processed from the command line and put into a standard book directory,
creating

- book/0001.nrm.png (deskewed grayscale page image)
- book/0001.bin.png (deskewed and cleaned binary page image)

This assumes 300dpi images (i.e., all the internal thresholds and constants
are set up for that).  If your image is a different resolution, use the -z (zoom)
argument.

This will work reasonably well for many kinds of inputs, but
ocropus-nlbin is the preferred binarizer now.
""")

parser.add_argument("files",default=[],nargs='*',help="input page images")

parser.add_argument("-o","--output",help="output directory",default=None)
parser.add_argument("-q","--quiet",action="store_true",help="disable warnings")
parser.add_argument("-Q","--parallel",type=int,default=0,help="number of parallel processes to use")
parser.add_argument("-g","--gtextension",help="ground truth extension for copying in ground truth (include all dots)",default=None)
parser.add_argument("--debug",help="display intermediate results for debugging",action="store_true")
parser.add_argument("--show",help="show binarized output",action="store_true")

# parser.add_argument("--dpi",default=300,type=float,help="resolution (DPI) (300)")
parser.add_argument("-z","--zoom",type=float,default=1.0,help="rescale the image prior to processing")
parser.add_argument("--maxsize",type=int,default=300,help="maximum character component size")
parser.add_argument("--minsize",type=int,default=5,help="minimum character component size")
parser.add_argument("--binarize",action="store_true",help="always run binarization, even if image appears binary already")
parser.add_argument("--invert",action="store_true",help="invert the image prior to binarization")
parser.add_argument("--htrem",action="store_true",help="always remove halftones (even if there don't appear to be any)")
parser.add_argument("--nohtrem",action="store_true",help="never remove halftones (even if there appear to be some)")
parser.add_argument("--uncleaned",action="store_true",help="output only the deskewed binary image with no further cleanup")
parser.add_argument("--noskew",action="store_true",help="do not perform skew correction")

# sauvola
parser.add_argument("-s","--sigma",type=float,default=150,help="sigma argument for Sauvola binarization")
parser.add_argument("-k","--k",type=float,default=0.3,help="k value for Sauvola binarization")

# Banner: the command line, truncated to 60 characters, framed by blank
# lines.  Single-argument print(...) produces the same output whether it is
# parsed as a Python 2 print statement or a Python 3 function call.
print("")
print("%s %s"%("#"*10,(" ".join(sys.argv))[:60]))
print("")

# hysteresis thresholding
# TBD

args = parser.parse_args()

# expand the file arguments through ocrolib's globbing helper
args.files = ocrolib.glob_all(args.files)

# nothing to do: show usage and exit successfully
if len(args.files)<1:
    parser.print_help()
    sys.exit(0)

# interactive display only makes sense in a single process
if args.debug or args.show: args.parallel = 1

################################################################
# preprocessing
################################################################

import os,os.path
from pylab import *
from scipy.ndimage import filters

################################################################
### Binarization
################################################################

def is_binary(image):
    """Return true if more than 99% of the pixels of `image` sit at its
    minimum or its maximum value.  (For a constant image every pixel is
    counted twice, so the result is true as well.)"""
    darkest = amin(image)
    brightest = amax(image)
    at_extremes = sum(image==darkest)+sum(image==brightest)
    return at_extremes > 0.99*image.size
    
def _gsauvola_upsample(small,scale,shape):
    """Zoom `small` by `scale` (nearest neighbour) and fit the result to
    `shape`: excess rows/columns are cropped, missing ones are filled by
    repeating the last available row/column."""
    big = interpolation.zoom(small,scale,order=0,mode='nearest')
    # explicit comparisons rather than min(): the file star-imports pylab,
    # which may shadow the builtin
    h = big.shape[0] if big.shape[0]<shape[0] else shape[0]
    w = big.shape[1] if big.shape[1]<shape[1] else shape[1]
    out = zeros(shape)
    out[:h,:w] = big[:h,:w]
    out[h:,:w] = out[h-1:h,:w]
    out[:,w:] = out[:,w-1:w]
    return out

def gsauvola(image,sigma=150.0,R=None,k=0.3,filter='uniform',scale=2.0):
    """Perform Sauvola-like binarization.  This uses linear filters to
    compute the local mean and variance at every pixel.

    image  -- gray scale image; uint8 input is divided by 256.0, and a
              three-dimensional array is averaged over its last axis
    sigma  -- size argument handed to the smoothing filter (truncated to an
              integer for the uniform filter, which takes a window size)
    R      -- normalizer for the local standard deviation; defaults to the
              largest local standard deviation found in the image
    k      -- Sauvola sensitivity parameter
    filter -- "uniform", "gaussian", or a callable taking (array, sigma)
    scale  -- the local statistics are computed on the image downsampled
              by this factor and then scaled back up

    Returns a uint8 array of the input's 2D shape holding 255 where the
    pixel is brighter than its local threshold and 0 elsewhere.  Raises
    ValueError for an unrecognized filter."""
    if image.dtype==dtype('uint8'): image = image / 256.0
    if len(image.shape)==3: image = mean(image,axis=2)

    if filter=="gaussian":
        smooth = lambda a: filters.gaussian_filter(a,sigma)
    elif filter=="uniform":
        smooth = lambda a: filters.uniform_filter(a,int(sigma))
    elif callable(filter):
        smooth = lambda a: filter(a,sigma)
    else:
        raise ValueError("unknown filter: %r"%(filter,))

    # local mean and standard deviation at reduced resolution
    scaled = interpolation.zoom(image,1.0/scale,order=0,mode='nearest')
    norm = smooth(ones(scaled.shape))
    avg_ = smooth(scaled) / norm
    stddev_ = maximum(smooth(scaled**2)/norm - avg_**2,0.0)**0.5

    # Back to full resolution.  The zoomed size need not equal the image size
    # (odd dimensions, rounding), so the result is fitted to the image rather
    # than zoomed into a fixed-size output slice.
    avg = _gsauvola_upsample(avg_,scale,image.shape)
    stddev = _gsauvola_upsample(stddev_,scale,image.shape)

    if R is None:
        R = amax(stddev)
        # a perfectly flat image has no variation; avoid dividing by zero
        if R<=0: R = 1.0
    thresh = avg * (1.0 + k * (stddev / R - 1.0))
    return array(255*(image>thresh),'uint8')

def inverse(image):
    """Flip the intensities of `image` around its maximum: the brightest
    value becomes 0 and every other value becomes its distance to it."""
    brightest = amax(image)
    return brightest-image

################################################################
### Bounding-box operations.
################################################################

def bounding_boxes_math(image):
    """Threshold `image` halfway between its minimum and maximum, label the
    connected components of the bright part, and return one (x0,y0,x1,y1)
    tuple per component.  The y values are measured upwards from the bottom
    edge of the image (mathematical orientation)."""
    midpoint = mean([amax(image),amin(image)])
    labels,_ = measurements.label(image>midpoint)
    height,_width = labels.shape
    return [(cols.start,height-rows.stop,cols.stop,height-rows.start)
            for rows,cols in measurements.find_objects(labels)]

def estimate_skew_angle(image,angles=linspace(-2.0,2.0,11)):
    """Estimate the rotation that best aligns the content of `image` with
    the rows.

    For every candidate in `angles` the image is rotated by that angle (the
    value is handed to interpolation.rotate unchanged, which takes degrees)
    and the variance of the row means is computed.  The candidate with the
    largest variance is returned.  Exact ties -- e.g. a blank image, where
    every variance is identical -- are resolved in favour of the angle
    closest to zero, so that no rotation is suggested without evidence.

    With --debug the variance is plotted against the angle.  Raises
    ValueError if `angles` is empty."""
    estimates = []
    for a in angles:
        profile = mean(interpolation.rotate(image,a,order=0,mode='constant'),axis=1)
        estimates.append((var(profile),a))
    if len(estimates)<1:
        raise ValueError("no candidate angles given")
    if args.debug>0:
        plot([a for v,a in estimates],[v for v,a in estimates])
        ginput(1,args.debug)
    # explicit scan instead of max(): keeps the tie rule visible and does not
    # depend on which `max` the pylab star import leaves in scope
    best_v,best_a = estimates[0]
    for v,a in estimates[1:]:
        if v>best_v or (v==best_v and abs(a)<abs(best_a)):
            best_v,best_a = v,a
    return best_a
    
def check_contains_halftones(image,dpi=300.0):
    """Heuristic for deciding whether halftone removal should be applied.
    Returns true when fewer than 30% of the connected components are larger
    than 4 pixels (at 300 dpi, scaled with `dpi`) in width or height."""
    boxes = bounding_boxes_math(image)
    limit = 4*dpi/300.0
    large = [1 for (x0,y0,x1,y1) in boxes if x1-x0>limit or y1-y0>limit]
    return len(large)<0.3*len(boxes)

def remove_small_components(image,r=3):
    """Delete every connected component whose height and width are both at
    most `r`.  Returns a boolean mask of the components that remain."""
    labels,_ = measurements.label(image)
    for label,box in enumerate(measurements.find_objects(labels),1):
        height = box[0].stop-box[0].start
        width = box[1].stop-box[1].start
        if height<=r and width<=r:
            # clear only this component's pixels inside its bounding box
            window = labels[box]
            window[window==label] = 0
    return labels!=0

def remove_big_components(image,r=100):
    """Delete every connected component whose height or width is `r` or
    more.  Returns a boolean mask of the components that remain."""
    labels,_ = measurements.label(image)
    for label,box in enumerate(measurements.find_objects(labels),1):
        rows,cols = box[0],box[1]
        too_tall = rows.stop-rows.start>=r
        too_wide = cols.stop-cols.start>=r
        if too_tall or too_wide:
            # clear only this component's pixels inside its bounding box
            window = labels[box]
            window[window==label] = 0
    return labels!=0

def remove_small_any(image,r=3):
    """Remove both small connected components and small holes.

    Small components are removed first; the result is then complemented so
    that holes become components, filtered again with the same `r`, and
    complemented back.  Returns a boolean mask."""
    kept = remove_small_components(image,r=r)
    # remove_small_components returns a boolean mask, so the complement is
    # taken with logical_not; arithmetic subtraction is not defined for
    # boolean arrays in current NumPy.
    background = remove_small_components(logical_not(kept),r=r)
    return logical_not(background)

def rectangular_cover(image,minsize=5):
    """Image-to-image transformation that replaces each connected component
    by its filled bounding box.  Components narrower or shorter than
    `minsize` are left out.  The result is a float array of zeros and ones."""
    labels,_ = measurements.label(image)
    cover = zeros(labels.shape)
    for box in measurements.find_objects(labels):
        height = box[0].stop-box[0].start
        width = box[1].stop-box[1].start
        if height>=minsize and width>=minsize:
            cover[box] = 1
    return cover

def find_halftones(image,dpi=300.0,threshold=0.05,r=5,sigma=15.0,cover=1):
    """Locate halftone regions.  The pixels that remove_small_any changes
    (small components and small holes) are smoothed with a Gaussian of
    width sigma*dpi/300 and thresholded.  With `cover` set, the rectangular
    cover of the thresholded density is returned; otherwise the union of
    the changed pixels and the thresholded density."""
    despeckled = remove_small_any(image,r=r)
    changed = ((image!=0)!=(despeckled!=0))
    dense = filters.gaussian_filter(1.0*changed,sigma*dpi/300.0)>threshold
    if not cover:
        return maximum(changed,dense)
    return rectangular_cover(dense)

def remove_halftones(image,dpi=300.0,threshold=0.05,r=5,sigma=15.0):
    """Blank out the regions reported by find_halftones: the image maximum
    is subtracted inside those regions and the result is clipped at zero."""
    mask = find_halftones(image,dpi=dpi,threshold=threshold,r=r,sigma=sigma)
    erased = image-amax(image)*mask
    return maximum(erased,0)

################################################################
### All preprocessing steps put together.
################################################################

def IMDEBUG(image,label):
    """When --debug is active, display `image` in the current axes with
    `label` as the x axis label and wait for a mouse click (up to 9999
    seconds).  Does nothing otherwise."""
    if not args.debug: return
    cla()
    xlabel(label)
    imshow(image)
    ginput(1,9999)

def preprocess(raw,title=None):
    """Run the complete preprocessing chain on one gray scale page image.

    Steps: optional zoom (--zoom), binarization (Sauvola with --sigma/--k,
    or a simple mid-range threshold when the input already looks binary),
    optional inversion (--invert), cleanup of a copy of the binary image
    (halftone removal, --minsize/--maxsize component filtering), and skew
    correction.

    raw   -- gray scale input image
    title -- label for the input image in the debug display

    Returns (binary, grayscale).  With --noskew this is the cleaned binary
    image and the (zoomed/inverted) input; otherwise both the binarized
    image and the input are rotated by the estimated skew angle.

    NOTE(review): in the deskewing path the cleaned image is only used for
    estimating the angle and the returned binary image is the uncleaned
    one, while --noskew returns the cleaned one -- confirm which is
    intended."""
    if args.debug:
        subplot(121); imshow(raw); ginput(1,0.001)
        if title is not None: xlabel(title)
        subplot(122)

    # zoom if requested

    if args.zoom!=1.0:
        raw = interpolation.zoom(raw,args.zoom,mode='nearest',order=1)
        IMDEBUG(raw,"zoomed")

    # binarize if the image isn't already binary

    if args.parallel<2: print("binarizing")
    if args.binarize or not is_binary(raw):
        # hand the command line parameters through to the binarizer
        binary = gsauvola(raw,sigma=args.sigma,k=args.k)
    else:
        binary = array(255*(raw>0.5*(amax(raw)+amin(raw))),'B')
    assert amax(binary)>amin(binary),"something went wrong with binarization"

    # invert or detect inverted scans

    if args.invert:
        raw = inverse(raw)
        binary = inverse(binary)

    IMDEBUG(binary,"binarized and inverted")

    # now clean up for skew estimation

    if args.parallel<2: print("cleaning")
    cleaned = binary

    if args.htrem or (not args.nohtrem and check_contains_halftones(binary)):
        cleaned = remove_halftones(cleaned)
        IMDEBUG(cleaned,"half tone removal")

    if args.minsize>0:
        cleaned = remove_small_components(cleaned,args.minsize)
        IMDEBUG(cleaned,"minsize filtering")

    if args.maxsize<10000:
        cleaned = remove_big_components(cleaned,args.maxsize)
        IMDEBUG(cleaned,"maxsize filtering")

    if args.noskew:
        return cleaned,raw

    # perform skew estimation

    if args.parallel<2: print("estimating skew angle")
    a = estimate_skew_angle(cleaned)
    print("got %s"%(a,))

    # `binary` has not been touched by the cleanup steps above (they work on
    # `cleaned`), so --uncleaned needs no special handling here.

    # estimate_skew_angle returns, in degrees, the rotation that straightens
    # the page, so the angle is applied as it is.
    if args.parallel<2: print("rotating")
    flat = interpolation.rotate(raw,a,mode='nearest',order=0)
    binary = interpolation.rotate(binary,a,mode='nearest',order=0)
    binary = array(255*(binary>0.5*(amin(binary)+amax(binary))),'B')
    IMDEBUG(binary,"skew correction by %f"%a)
    if args.parallel<2: print("writing")
    return binary,flat

################################################################
### main loop
################################################################

# Interactive display: turn on matplotlib's interactive mode and bring up
# the window when intermediate or final results are to be shown.
if args.debug or args.show:
    ion()
    show()

# module-level placeholder
files = None

def process1(job):
    """Preprocess a single page and write the results.

    job -- (fname, count): the input file name and its 1-based number

    Writes <base>.nrm.png (gray scale) and <base>.bin.png (binary), where
    <base> is <output>/NNNN when --output is given and the input path
    without its extensions otherwise.  Empty (constant) images are skipped;
    a failure in preprocess() is reported with a traceback and the page is
    skipped.  Returns None."""
    fname,count = job

    if args.parallel<2: print("sauvola %s %s"%(fname,count))

    image = ocrolib.read_image_gray(fname)
    if amax(image)<=amin(image)+1e-4:
        print("%s is empty"%(fname,))
        return

    IMDEBUG(image,"image")

    title = None
    if args.debug:
        ion(); clf(); pylab.gray()
        title = "%s %s"%(fname,count)
    try:
        bin,flat = preprocess(image,title=title)
    except Exception:
        # Only ordinary errors are reported and skipped.  The SIGINT handler
        # raises SystemExit, which has to propagate so that Ctrl-C stops the
        # run instead of being logged as a failed page.
        traceback.print_exc()
        print("%s failed"%(fname,))
        return

    if args.show:
        clf()
        imshow(bin,cmap=cm.gray)
        draw()

    pylab.gray()
    if args.output is not None:
        base = os.path.join(args.output,"%04d"%count)
    else:
        base,_ = ocrolib.allsplitext(fname)
    ocrolib.write_image_gray(base+".nrm.png",flat,verbose=1)
    ocrolib.write_image_binary(base+".bin.png",bin,verbose=1)


# debugging is interactive, so it always runs in this process
if args.debug>0: args.parallel = 0

# create the output directory, including missing parent directories
if args.output:
    if not os.path.exists(args.output):
        os.makedirs(args.output)

# one job per input file: (file name, 1-based page number)
jobs = [(f,i+1) for i,f in enumerate(args.files)]

if args.parallel<2:
    for job in jobs:
        process1(job)
else:
    pool = multiprocessing.Pool(processes=args.parallel)
    try:
        result = pool.map(process1,jobs)
    finally:
        # map() has returned (or was interrupted), so no work is pending;
        # stop the workers and wait for them instead of leaving them behind
        pool.terminate()
        pool.join()
